{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from simpletransformers.classification import MultiLabelClassificationModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv('raw_data/train.csv')\n",
    "test = pd.read_csv('raw_data/nlp_test.csv')\n",
    "sub = pd.read_csv('raw_data/sample_submission.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ID</th>\n",
       "      <th>category_A</th>\n",
       "      <th>category_B</th>\n",
       "      <th>category_C</th>\n",
       "      <th>category_D</th>\n",
       "      <th>category_E</th>\n",
       "      <th>category_F</th>\n",
       "      <th>Question Sentence</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>QC0000000001</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>病情描述：病人是典型的“三高”，想吃拜阿司匹林做为预防用药，但是出现过敏症状。曾经治疗情况和...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>QC0000000002</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>病情描述：我父亲78岁,小脑梗塞，表现左眼双视，经住院输液治疗现在恢复一月余，现在看人不重影...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>QC0000000003</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>病情描述：医生您好！我想问一下，我妈脑梗塞今天住院十一天了，情况好转，可以下床活动和清楚表达...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>QC0000000004</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>停药反跳的血压是身体真实状况吗?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>QC0000000005</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>病情描述：医生：银杏叶片生产厂很多，价格差别很大，应怎样选择？曾经治疗情况和效果：头部紧束感...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             ID  category_A  category_B  category_C  category_D  category_E  \\\n",
       "0  QC0000000001           0           1           0           0           0   \n",
       "1  QC0000000002           0           1           0           0           0   \n",
       "2  QC0000000003           0           1           0           0           0   \n",
       "3  QC0000000004           1           0           0           0           0   \n",
       "4  QC0000000005           0           1           0           0           0   \n",
       "\n",
       "   category_F                                  Question Sentence  \n",
       "0           0  病情描述：病人是典型的“三高”，想吃拜阿司匹林做为预防用药，但是出现过敏症状。曾经治疗情况和...  \n",
       "1           0  病情描述：我父亲78岁,小脑梗塞，表现左眼双视，经住院输液治疗现在恢复一月余，现在看人不重影...  \n",
       "2           0  病情描述：医生您好！我想问一下，我妈脑梗塞今天住院十一天了，情况好转，可以下床活动和清楚表达...  \n",
       "3           0                                   停药反跳的血压是身体真实状况吗?  \n",
       "4           0  病情描述：医生：银杏叶片生产厂很多，价格差别很大，应怎样选择？曾经治疗情况和效果：头部紧束感...  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ID</th>\n",
       "      <th>Question Sentence</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>QCT1000000001</td>\n",
       "      <td>乙肝抗体多久打一次乙肝抗体多久打一次，戊型肝炎的传播途径是怎样的，伴侣中有一个是乙肝携带，不...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>QCT1000000002</td>\n",
       "      <td>十日观察法我没被狗咬伤，我想问问，狂犬病不是有十日观察法吗，那如果说狗咬人时狗得了狂犬病，可...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>QCT1000000003</td>\n",
       "      <td>我的乙肝大三阳十年了，现在能治的好了吗沈阳肝病携带者能治好，我的乙肝大三阳十年了，现在能治的...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>QCT1000000004</td>\n",
       "      <td>治疗丙肝哪里治疗最好丙肝的预防与治疗，治疗丙肝哪里治疗最好，丙肝表现症状是什么呢？丙肝有哪些...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>QCT1000000005</td>\n",
       "      <td>婚检乙肝上千倍婚检乙肝上千倍，我父亲肝硬化失代偿期，有单个结节癌变。有乙肝病史已经射频手术，...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              ID                                  Question Sentence\n",
       "0  QCT1000000001  乙肝抗体多久打一次乙肝抗体多久打一次，戊型肝炎的传播途径是怎样的，伴侣中有一个是乙肝携带，不...\n",
       "1  QCT1000000002  十日观察法我没被狗咬伤，我想问问，狂犬病不是有十日观察法吗，那如果说狗咬人时狗得了狂犬病，可...\n",
       "2  QCT1000000003  我的乙肝大三阳十年了，现在能治的好了吗沈阳肝病携带者能治好，我的乙肝大三阳十年了，现在能治的...\n",
       "3  QCT1000000004  治疗丙肝哪里治疗最好丙肝的预防与治疗，治疗丙肝哪里治疗最好，丙肝表现症状是什么呢？丙肝有哪些...\n",
       "4  QCT1000000005  婚检乙肝上千倍婚检乙肝上千倍，我父亲肝硬化失代偿期，有单个结节癌变。有乙肝病史已经射频手术，..."
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ID</th>\n",
       "      <th>category_A</th>\n",
       "      <th>category_B</th>\n",
       "      <th>category_C</th>\n",
       "      <th>category_D</th>\n",
       "      <th>category_E</th>\n",
       "      <th>category_F</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>QCT1000000001</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>QCT1000000002</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>QCT1000000003</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>QCT1000000004</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>QCT1000000005</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              ID  category_A  category_B  category_C  category_D  category_E  \\\n",
       "0  QCT1000000001           1           1           0           0           1   \n",
       "1  QCT1000000002           0           1           0           1           1   \n",
       "2  QCT1000000003           0           1           1           0           1   \n",
       "3  QCT1000000004           1           1           0           1           0   \n",
       "4  QCT1000000005           0           1           0           0           1   \n",
       "\n",
       "   category_F  \n",
       "0           0  \n",
       "1           1  \n",
       "2           0  \n",
       "3           0  \n",
       "4           1  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sub.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((5000, 8), (3000, 2))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.shape, test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "count    5000.000000\n",
       "mean      110.758800\n",
       "std       127.906681\n",
       "min         4.000000\n",
       "25%        42.000000\n",
       "50%        77.000000\n",
       "75%       132.000000\n",
       "max      2093.000000\n",
       "Name: Question Sentence, dtype: float64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train['Question Sentence'].apply(lambda x: len(x)).describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "count    3000.000000\n",
       "mean      144.382667\n",
       "std       224.694973\n",
       "min        19.000000\n",
       "25%        54.000000\n",
       "50%        82.000000\n",
       "75%       134.000000\n",
       "max      3095.000000\n",
       "Name: Question Sentence, dtype: float64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test['Question Sentence'].apply(lambda x: len(x)).describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>labels</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>病情描述：病人是典型的“三高”，想吃拜阿司匹林做为预防用药，但是出现过敏症状。曾经治疗情况和...</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>病情描述：我父亲78岁,小脑梗塞，表现左眼双视，经住院输液治疗现在恢复一月余，现在看人不重影...</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>病情描述：医生您好！我想问一下，我妈脑梗塞今天住院十一天了，情况好转，可以下床活动和清楚表达...</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>停药反跳的血压是身体真实状况吗?</td>\n",
       "      <td>[1, 0, 0, 0, 0, 0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>病情描述：医生：银杏叶片生产厂很多，价格差别很大，应怎样选择？曾经治疗情况和效果：头部紧束感...</td>\n",
       "      <td>[0, 1, 0, 0, 0, 0]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                text              labels\n",
       "0  病情描述：病人是典型的“三高”，想吃拜阿司匹林做为预防用药，但是出现过敏症状。曾经治疗情况和...  [0, 1, 0, 0, 0, 0]\n",
       "1  病情描述：我父亲78岁,小脑梗塞，表现左眼双视，经住院输液治疗现在恢复一月余，现在看人不重影...  [0, 1, 0, 0, 0, 0]\n",
       "2  病情描述：医生您好！我想问一下，我妈脑梗塞今天住院十一天了，情况好转，可以下床活动和清楚表达...  [0, 1, 0, 0, 0, 0]\n",
       "3                                   停药反跳的血压是身体真实状况吗?  [1, 0, 0, 0, 0, 0]\n",
       "4  病情描述：医生：银杏叶片生产厂很多，价格差别很大，应怎样选择？曾经治疗情况和效果：头部紧束感...  [0, 1, 0, 0, 0, 0]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df = train[['Question Sentence']].copy()\n",
    "train_df['labels'] = train[['category_A','category_B','category_C','category_D','category_E','category_F']].values.tolist()\n",
    "train_df.columns = ['text', 'labels']\n",
    "\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4000, 2), (1000, 2))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_df = train_df[4000:]\n",
    "train_df = train_df[:4000]\n",
    "\n",
    "train_df.shape, eval_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = {\n",
    "    'num_train_epochs': 5,\n",
    "    'fp16': False,\n",
    "    'no_save': True,\n",
    "    'overwrite_output_dir': True,\n",
    "    'max_seq_length': 128,\n",
    "    'evaluate_during_training': True,\n",
    "    'evaluate_during_training_verbose': True\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertForMultiLabelSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing BertForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForMultiLabelSequenceClassification were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = MultiLabelClassificationModel('bert', \n",
    "                                      'hfl/chinese-roberta-wwm-ext', \n",
    "                                      num_labels=6, \n",
    "                                      args=args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dc5c58890d9846cb8cfd8cc77f5e47df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=4000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0017c24bbf6b4984a154b04ac88a366d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=5.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ced930869234075b411d49715a77879",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 5', max=500.0, style=ProgressStyle(des…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "994ac2b4ec7c4ea28c1579474c711a12",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 5', max=500.0, style=ProgressStyle(des…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "57244a6d36374eb887870c5eb399a87a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 5', max=500.0, style=ProgressStyle(des…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b5799af88974b09947a99a9ec55e16d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 3 of 5', max=500.0, style=ProgressStyle(des…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4b1efa1f44df472797f91f6c3c85178e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 4 of 5', max=500.0, style=ProgressStyle(des…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model.train_model(train_df, eval_df=eval_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f52c8cd8b0894c43a196978dc87700cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=3000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "856eb7b1994e4eb2808b30495a7018c2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=375.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "results = model.predict([text for text in test['Question Sentence']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ID</th>\n",
       "      <th>category_A</th>\n",
       "      <th>category_B</th>\n",
       "      <th>category_C</th>\n",
       "      <th>category_D</th>\n",
       "      <th>category_E</th>\n",
       "      <th>category_F</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>QCT1000000001</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>QCT1000000002</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>QCT1000000003</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>QCT1000000004</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>QCT1000000005</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              ID  category_A  category_B  category_C  category_D  category_E  \\\n",
       "0  QCT1000000001           1           0           0           1           0   \n",
       "1  QCT1000000002           0           1           0           0           0   \n",
       "2  QCT1000000003           0           1           0           1           0   \n",
       "3  QCT1000000004           1           1           0           0           0   \n",
       "4  QCT1000000005           0           1           0           0           0   \n",
       "\n",
       "   category_F  \n",
       "0           0  \n",
       "1           0  \n",
       "2           0  \n",
       "3           0  \n",
       "4           0  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sub['category_A'] = np.array(results[0])[:, 0]\n",
    "sub['category_B'] = np.array(results[0])[:, 1]\n",
    "sub['category_C'] = np.array(results[0])[:, 2]\n",
    "sub['category_D'] = np.array(results[0])[:, 3]\n",
    "sub['category_E'] = np.array(results[0])[:, 4]\n",
    "sub['category_F'] = np.array(results[0])[:, 5]\n",
    "\n",
    "sub.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub.to_csv('roberta_wwm_ext_baseline.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
